{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- tokenization_internlm_xcomposer2.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- configuration_chartmoe.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- build_mlp.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- modeling_internlm2.py\n",
      "- build_mlp.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- build_moe_connector.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
      "- modeling_chartmoe.py\n",
      "- modeling_internlm2.py\n",
      "- build_moe_connector.py\n",
      ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
      "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Downloading shards: 100%|██████████| 2/2 [00:00<00:00,  4.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set max length to 4096\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.16s/it]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# Expose only GPU index 2 to this process; set here, before `import torch` below.\n",
    "# NOTE(review): the device index is hard-coded - adjust it for the machine at hand.\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='2'\n",
    "import sys  \n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "from PIL import Image \n",
    "from tqdm import tqdm\n",
    "import datetime\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "from datasets import load_dataset\n",
    "from chartmoe import ChartMoE_Robot\n",
    "\n",
    "# Load the `lmms-lab/MME` dataset and keep only its `test` split.\n",
    "mme_data = load_dataset(\"lmms-lab/MME\")['test']\n",
    "\n",
    "# Build the ChartMoE chat wrapper with its default constructor arguments.\n",
    "robot = ChartMoE_Robot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MME sub-task (category) names grouped by evaluation track.\n",
    "# Later cells sum the per-category scores listed under each track.\n",
    "eval_type_dict = dict(\n",
    "    Perception=[\n",
    "        \"existence\", \"count\", \"position\", \"color\", \"posters\",\n",
    "        \"celebrity\", \"scene\", \"landmark\", \"artwork\", \"OCR\",\n",
    "    ],\n",
    "    Cognition=[\n",
    "        \"commonsense_reasoning\", \"numerical_calculation\",\n",
    "        \"text_translation\", \"code_reasoning\",\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_pred_ans(pred_ans):\n",
    "    \"\"\"Normalise a free-form model reply to `yes`, `no` or `other`.\n",
    "\n",
    "    Logic brought from Otter Eval. The reply is lower-cased, stripped of\n",
    "    surrounding whitespace and has every `.` removed. An exact `yes`/`no`\n",
    "    is returned as-is; a single character maps `y`/`n` to `yes`/`no`;\n",
    "    otherwise only the first four characters are searched, for `yes` first\n",
    "    and then for `no`. Everything else is `other`.\n",
    "\n",
    "    Args:\n",
    "        pred_ans (str): Raw text produced by the model.\n",
    "\n",
    "    Returns:\n",
    "        str: One of `yes`, `no` or `other`.\n",
    "    \"\"\"\n",
    "    cleaned = pred_ans.lower().strip().replace(\".\", \"\")\n",
    "\n",
    "    # Exact answer: nothing more to decide.\n",
    "    if cleaned in (\"yes\", \"no\"):\n",
    "        return cleaned\n",
    "\n",
    "    # One-letter replies only count when they abbreviate yes/no.\n",
    "    if len(cleaned) == 1:\n",
    "        return {\"y\": \"yes\", \"n\": \"no\"}.get(cleaned, \"other\")\n",
    "\n",
    "    # Longer (or empty) replies: inspect the first four characters only,\n",
    "    # giving `yes` priority over `no`.\n",
    "    head = cleaned[:4]\n",
    "    for label in (\"yes\", \"no\"):\n",
    "        if label in head:\n",
    "            return label\n",
    "    return \"other\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2374 [00:00<?, ?it/s]/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/transformers/generation/utils.py:1417: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n",
      "  warnings.warn(\n",
      "100%|██████████| 2374/2374 [58:05<00:00,  1.47s/it]\n"
     ]
    }
   ],
   "source": [
    "# Run ChartMoE on every MME sample and score its yes/no answer.\n",
    "# Hoisted out of the loop: the Perception category list never changes.\n",
    "perception_categories = set(eval_type_dict[\"Perception\"])\n",
    "\n",
    "results = []\n",
    "for d in tqdm(mme_data):\n",
    "    question = d['question']\n",
    "    category = d['category']\n",
    "    gt_ans = d[\"answer\"].lower().strip().replace(\".\", \"\")\n",
    "    # Check the reference answer before the generation call so a malformed\n",
    "    # sample fails immediately rather than after a model forward pass.\n",
    "    assert gt_ans in [\"yes\", \"no\"], f\"unexpected ground-truth answer: {gt_ans!r}\"\n",
    "\n",
    "    image = d['image'].convert(\"RGB\")\n",
    "    with torch.cuda.amp.autocast():\n",
    "        pred, _ = robot.chat(\n",
    "            image=image,\n",
    "            question=question,\n",
    "            temperature=1.0,\n",
    "            max_new_tokens=5,\n",
    "            num_beams=5,\n",
    "            do_sample=False,\n",
    "            repetition_penalty=1.0\n",
    "        )\n",
    "\n",
    "    # parse_pred_ans yields yes / no / other; only an exact match scores.\n",
    "    pred_ans = parse_pred_ans(pred)\n",
    "    score = 1.0 if pred_ans == gt_ans else 0.0\n",
    "\n",
    "    # NOTE(review): `percetion` is misspelled, but the key ends up in\n",
    "    # mme_results.jsonl, so it is left unchanged here - confirm no consumer\n",
    "    # relies on it before renaming.\n",
    "    key_name = \"mme_percetion_score\" if category in perception_categories else \"mme_cognition_score\"\n",
    "\n",
    "    results.append({key_name: {\"question_id\": d[\"question_id\"], \"category\": category, \"score\": score}})\n",
    "\n",
    "# Persist one JSON object per line.\n",
    "with open(\"mme_results.jsonl\", 'w') as f:\n",
    "    for res in results:\n",
    "        f.write(f\"{json.dumps(res)}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2214.1313525410164\n"
     ]
    }
   ],
   "source": [
    "# Aggregate per-sample scores into per-category MME scores and an overall total.\n",
    "#\n",
    "# Each entry of `results` is {key_name: payload}. Unwrap the payloads into a\n",
    "# new list instead of rebinding `results`: overwriting it made this cell raise\n",
    "# a TypeError when executed a second time.\n",
    "flat_results = [list(res.values())[0] for res in results]\n",
    "\n",
    "# category -> question_id -> list of scores recorded for that question_id\n",
    "category2score = defaultdict(lambda: defaultdict(list))\n",
    "for result in flat_results:\n",
    "    category2score[result[\"category\"]][result[\"question_id\"]].append(result[\"score\"])\n",
    "\n",
    "category2avg_score = {}\n",
    "for category, question2scores in category2score.items():\n",
    "    category_total = 0\n",
    "    for question_id, pair_scores in question2scores.items():\n",
    "        assert len(pair_scores) == 2, \"MME only supports pairwise evaluation\"\n",
    "        # Mean correctness of the pair, in percent.\n",
    "        acc = sum(pair_scores) / len(pair_scores) * 100.0\n",
    "        # 100 only when both answers of the pair are correct, else 0.\n",
    "        acc_plus = (sum(pair_scores) == 2) * 100.0\n",
    "        category_total += acc_plus + acc\n",
    "    category2avg_score[category] = category_total / len(question2scores)\n",
    "\n",
    "total_score = sum(category2avg_score.values())\n",
    "print(total_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'code_reasoning': 117.5,\n",
       " 'artwork': 186.25,\n",
       " 'celebrity': 163.8235294117647,\n",
       " 'numerical_calculation': 147.5,\n",
       " 'text_translation': 155.0,\n",
       " 'count': 170.0,\n",
       " 'color': 165.0,\n",
       " 'commonsense_reasoning': 140.71428571428572,\n",
       " 'position': 158.33333333333334,\n",
       " 'OCR': 125.0,\n",
       " 'landmark': 172.0,\n",
       " 'scene': 157.5,\n",
       " 'existence': 180.0,\n",
       " 'posters': 175.51020408163265}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-category scores: accuracy plus accuracy-plus, averaged over question pairs.\n",
    "category2avg_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {'Perception': 1653.4170668267307, 'Cognition': 560.7142857142858})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Roll the per-category scores up into one total per evaluation track.\n",
    "scores = defaultdict(int)\n",
    "for eval_type, category_types in eval_type_dict.items():\n",
    "    for category_type in category_types:\n",
    "        scores[eval_type] = scores[eval_type] + category2avg_score[category_type]\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2214.1000000000004"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum of the two track scores, each rounded to one decimal place.\n",
    "# Derived from `scores` rather than from literals copied out of an earlier output.\n",
    "round(scores[\"Perception\"], 1) + round(scores[\"Cognition\"], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.19 ('intern_clean')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a726b55af14ee9f10619d25e42820b32d50f8ab305998596bdf5d4abd3695153"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
